{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training included SpaceNet models with the `solaris` Python API\n",
    "\n",
    "`solaris` ships with a number of SpaceNet models, including pre-trained model weights. See [the pretrained models page](../pretrained_models.html) for the available model choices and [the SpaceNet Off-Nadir solutions repository](https://github.com/spacenetchallenge/spacenet_off_nadir_solutions) for the original competitors' code.\n",
    "\n",
    "For this tutorial we'll walk through training a model using XD_XD's SpaceNet 4 model. We'll use the config file for that model, which you can find [here](https://github.com/CosmiQ/solaris/blob/master/solaris/nets/configs/xdxd_spacenet4.yml).\n",
    "\n",
    "You'll also need to [create training masks](api_masks_tutorial.ipynb) and [create the image reference files](creating_im_reference_csvs.ipynb) before you start.\n",
    "\n",
    "Once you've completed those steps, you can get down to model training! First, let's load in the configuration:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'model_name': 'xdxd_spacenet4',\n",
       " 'model_path': None,\n",
       " 'train': False,\n",
       " 'infer': True,\n",
       " 'pretrained': True,\n",
       " 'nn_framework': 'torch',\n",
       " 'batch_size': 12,\n",
       " 'data_specs': {'width': 512,\n",
       "  'height': 512,\n",
       "  'image_type': 'zscore',\n",
       "  'rescale': False,\n",
       "  'rescale_minima': 'auto',\n",
       "  'rescale_maxima': 'auto',\n",
       "  'channels': 4,\n",
       "  'label_type': 'mask',\n",
       "  'is_categorical': False,\n",
       "  'mask_channels': 1,\n",
       "  'val_holdout_frac': 0.2,\n",
       "  'data_workers': None},\n",
       " 'training_data_csv': '/path/to/training_df.csv',\n",
       " 'validation_data_csv': None,\n",
       " 'inference_data_csv': '/path/to/test_df.csv',\n",
       " 'training_augmentation': {'augmentations': {'DropChannel': {'idx': 3,\n",
       "    'axis': 2},\n",
       "   'HorizontalFlip': {'p': 0.5},\n",
       "   'RandomRotate90': {'p': 0.5},\n",
       "   'RandomCrop': {'height': 512, 'width': 512, 'p': 1.0},\n",
       "   'Normalize': {'mean': [0.006479, 0.009328, 0.01123],\n",
       "    'std': [0.004986, 0.004964, 0.00495],\n",
       "    'max_pixel_value': 65535.0,\n",
       "    'p': 1.0}},\n",
       "  'p': 1.0,\n",
       "  'shuffle': True},\n",
       " 'validation_augmentation': {'augmentations': {'DropChannel': {'idx': 3,\n",
       "    'axis': 2},\n",
       "   'CenterCrop': {'height': 512, 'width': 512, 'p': 1.0},\n",
       "   'Normalize': {'mean': [0.006479, 0.009328, 0.01123],\n",
       "    'std': [0.004986, 0.004964, 0.00495],\n",
       "    'max_pixel_value': 65535.0,\n",
       "    'p': 1.0}},\n",
       "  'p': 1.0},\n",
       " 'inference_augmentation': {'augmentations': {'DropChannel': {'idx': 3,\n",
       "    'axis': 2,\n",
       "    'p': 1.0},\n",
       "   'Normalize': {'mean': [0.006479, 0.009328, 0.01123],\n",
       "    'std': [0.004986, 0.004964, 0.00495],\n",
       "    'max_pixel_value': 65535.0,\n",
       "    'p': 1.0}},\n",
       "  'p': 1.0},\n",
       " 'training': {'epochs': 60,\n",
       "  'steps_per_epoch': None,\n",
       "  'optimizer': 'Adam',\n",
       "  'lr': 0.0001,\n",
       "  'opt_args': None,\n",
       "  'loss': {'bcewithlogits': None, 'jaccard': None},\n",
       "  'loss_weights': {'bcewithlogits': 10, 'jaccard': 2.5},\n",
       "  'metrics': {'training': None, 'validation': None},\n",
       "  'checkpoint_frequency': 10,\n",
       "  'callbacks': {'model_checkpoint': {'filepath': 'xdxd_best.pth',\n",
       "    'monitor': 'val_loss'}},\n",
       "  'model_dest_path': 'xdxd.pth',\n",
       "  'verbose': True},\n",
       " 'inference': {'window_step_size_x': None,\n",
       "  'window_step_size_y': None,\n",
       "  'output_dir': 'inference_out/'}}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import solaris as sol\n",
    "\n",
    "config = sol.utils.config.parse('/path/to/xdxd_spacenet4.yml')\n",
    "config"
   ]
  },
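  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the example output above shows `'train': False` and `'infer': True`, while this tutorial is about training. Because the parsed config is an ordinary nested dictionary, one option is to adjust those values in Python rather than editing the YAML. The cell below is a minimal sketch of that idea, not part of the `solaris` API: `prepare_for_training` and `example_config` are hypothetical names introduced here, and `example_config` only mimics a few keys from the output above so the cell runs on its own. Whether `solaris` requires `train` to be true at this point is an assumption; editing the YAML file directly works just as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "\n",
    "def prepare_for_training(cfg, training_csv, epochs=None):\n",
    "    # Hypothetical helper (not part of solaris): return a modified copy of a\n",
    "    # parsed config dict with training switched on, leaving the input untouched.\n",
    "    new_cfg = copy.deepcopy(cfg)\n",
    "    new_cfg['train'] = True\n",
    "    new_cfg['training_data_csv'] = training_csv\n",
    "    if epochs is not None:\n",
    "        new_cfg['training']['epochs'] = epochs\n",
    "    return new_cfg\n",
    "\n",
    "\n",
    "# Stand-in for a few keys of the parsed config shown above (assumption: the\n",
    "# real config has the same nesting for these keys).\n",
    "example_config = {\n",
    "    'train': False,\n",
    "    'infer': True,\n",
    "    'training_data_csv': '/path/to/training_df.csv',\n",
    "    'training': {'epochs': 60, 'lr': 0.0001},\n",
    "}\n",
    "\n",
    "train_config = prepare_for_training(\n",
    "    example_config, '/path/to/training_df.csv', epochs=5)\n",
    "print(train_config['train'], train_config['training']['epochs'])\n",
    "\n",
    "# With the real parsed config you would instead write something like:\n",
    "# config = prepare_for_training(config, '/path/to/training_df.csv')"
   ]
  },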
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, `solaris` parses the YAML into a set of nested dictionaries. The relevant entries of that config are then read during training.\n",
    "\n",
    "Let's assume that all of the paths in that config are correct. To run training, create a trainer from the config and call its `train()` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = sol.nets.train.Trainer(config)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's all there is to it! With just four lines of Python code, you can train a model for geospatial ML!\n",
    "\n",
    "If you wish to use a custom ML model not provided by `solaris`, check out [the custom model training tutorial](api_training_custom.ipynb)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "solaris",
   "language": "python",
   "name": "solaris"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
